import traceback

from datasets.common import DatasetName
from datasets.frozen_embeddings.loader  import EmbeddedDataset
from experiments.experimental_pipeline import ExperimentalPipeline
import os

from typing import List, Union
# Model identifiers; each embedded dataset below is instantiated once per entry,
# passing the identifier as `model=`.
VISION_MODELS = ["resnet18", "resnet50", "vit_b_16", "swin_t", "efficientnet_b0"]


def _per_model(name: DatasetName) -> List[EmbeddedDataset]:
    """Return one EmbeddedDataset for `name` per VISION_MODELS entry, in list order."""
    return [EmbeddedDataset(name, model=model) for model in VISION_MODELS]


# Datasets referred to by plain string name.
classic_datasets = ["Diabetes", "CDR"]

# Embedded datasets, one variant per vision model.
flowers_datasets = _per_model(DatasetName.OXFORD_FLOWERS)
cifar_catdog_datasets = _per_model(DatasetName.CIFAR10_CAT_DOG)

# CIFAR-10 pairs other than cat/dog; the deer/horse pair is currently disabled.
cifar_other = [
    DatasetName.CIFAR10_AUTO_TRUCK,
    # DatasetName.CIFAR10_DEER_HORSE
]

# Grouped by dataset name first, then by model.
cifar_other_datasets = [
    dataset
    for name in cifar_other
    for dataset in _per_model(name)
]

# Everything the driver iterates over, in execution order.
DATASETS: List[Union[str, EmbeddedDataset]] = [
    "DogFish",
    "Enron",
    *cifar_catdog_datasets,
    *flowers_datasets,
    *classic_datasets,
    *cifar_other_datasets,
]

if __name__ == "__main__":
    # Script entry point: run the pipeline once per DATASETS entry, in order.
    # A failure in one dataset is reported and the loop moves on to the next;
    # only a keyboard interrupt aborts the whole sweep.
    banner = "=" * 80

    for dataset_name in DATASETS:
        print(f"\n{banner}")
        print(f"Starting experiment for dataset: {dataset_name}")
        print(banner)

        # Only the "Enron" entry gets an extra loader flag; all others get none.
        dataloader_flags = {"load_legacy": True} if dataset_name == "Enron" else {}

        try:
            # storage_allowance_mb is 1024 when DATA_DIRECTORY is set to a
            # non-empty value, otherwise 25.
            if os.getenv("DATA_DIRECTORY"):
                allowance_mb = 1024
            else:
                allowance_mb = 25

            pipeline = ExperimentalPipeline(
                dataset_name=dataset_name,
                verbosity=3,
                dataloader_flags=dataloader_flags,
                storage_allowance_mb=allowance_mb,
            )
            pipeline.run()
        except KeyboardInterrupt:
            # Let the interrupt propagate so the process actually stops.
            print("\nExperiment interrupted. Exiting.")
            raise
        except Exception as exc:
            # Report the failure with its full stack trace, then fall through
            # to the next loop iteration.
            print(f"\nAn error occurred while running the experiment for {dataset_name}: {exc}")
            traceback.print_exc()

